{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false
      },
      "source": [
        "# Optimize Large Language Models\n",
        "\n",
        "Reference: [optimize_llm](https://tvm.apache.org/docs/how_to/tutorials/optimize_llm.html)"
      ]
    },
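    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As large language models (LLMs) have become a popular research topic across many domains, deploying them on cloud and edge devices has become a challenging task. This tutorial demonstrates how to optimize a large language model with Apache TVM. We use the pre-trained TinyLlama model from Hugging Face and deploy it on different devices.\n",
        "\n",
        "The overall flow consists of the following steps:\n",
        "\n",
        "- **Construct or import a model**: Construct a neural network model, or import a pre-trained model from another framework (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for tensor programs.\n",
        "- **Perform composable optimizations**: Perform a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.\n",
        "- **Build and universal deployment**: Build the optimized model into a module deployable on the universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators."
      ]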
    },
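    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Construct the Model Architecture\n",
        "\n",
        "We will use the pre-trained TinyLlama model from Hugging Face. However, we usually load only the pre-trained weights from Hugging Face, not the model architecture, so we need to construct the architecture ourselves. Apache TVM provides a PyTorch-like API for this purpose."
      ]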
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
      ],
      "source": [
        "import os\n",
        "os.environ['PATH'] += ':/usr/local/cuda/bin' # 保证 nvcc 可以被找到\n",
        "import dataclasses\n",
        "import enum\n",
        "import os\n",
        "from pathlib import Path\n",
        "from pprint import pprint\n",
        "from typing import List, Optional\n",
        "\n",
        "import tvm\n",
        "from tvm import dlight, relax, te, tir\n",
        "from tvm.relax import register_pipeline\n",
        "from tvm.relax.frontend import nn\n",
        "from tvm.relax.frontend.nn import Tensor, op\n",
        "from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, TIRPagedKVCache\n",
        "from tvm.runtime import ShapeTuple"
      ]
    },
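    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, we need to define the model configuration. The configuration includes the key parameters of the model, such as the hidden size and the intermediate size. For convenience, we define a constant configuration specifically for the TinyLlama model."
      ]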
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@dataclasses.dataclass\n",
        "class LlamaConfig:\n",
        "    hidden_size: int = 2048\n",
        "    intermediate_size: int = 5632\n",
        "    num_attention_heads: int = 32\n",
        "    num_hidden_layers: int = 22\n",
        "    rms_norm_eps: float = 1e-05\n",
        "    vocab_size: int = 32000\n",
        "    rope_theta: int = 10000\n",
        "    context_window_size: int = 2048\n",
        "    prefill_chunk_size: int = 2048\n",
        "    num_key_value_heads: int = 4\n",
        "    head_dim: int = 64  # hidden_size // num_attention_heads\n",
        "\n",
        "\n",
        "dev = tvm.device(\"cuda\", 0)\n",
        "target = tvm.target.Target.from_device(dev)"
      ]
    },
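    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, we define the RoPE mode of the paged KV cache. The RoPE mode controls how rotary position embedding (RoPE) is applied to the query and key tensors. It can be set to `NONE`, `NORMAL`, or `INLINE`. With `NONE`, the KV cache does not apply RoPE to the query and key tensors. With `NORMAL`, RoPE is applied to the key tensor before it is added to the cache. With `INLINE`, the attention kernel applies RoPE to the query and key tensors on the fly."
      ]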
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class RopeMode(enum.IntEnum):\n",
        "    \"\"\"The RoPE mode of the Paged KV cache.\n",
        "    If it is none, the KV cache will not apply RoPE to q and k.\n",
        "    If it is normal, RoPE will be applied to k before adding k to cache.\n",
        "    Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.\n",
        "    \"\"\"\n",
        "\n",
        "    NONE = 0\n",
        "    NORMAL = 1\n",
        "    INLINE = 2"
      ]
    },
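    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reference for what these modes apply, a common formulation of RoPE is sketched here. This is the standard textbook form rather than a description of the TVM kernel: the symbols $m$, $d$, $\\Theta$, $\\theta_i$, $a$ and $b$ are introduced only for illustration, and which two coordinates $(x_a, x_b)$ form the $i$-th pair differs between implementations.\n",
        "\n",
        "For a token at position $m$, head dimension $d$ (`head_dim`) and base $\\Theta$ (`rope_theta`), the $i$-th coordinate pair of a query or key vector is rotated by the angle $m\\theta_i$:\n",
        "\n",
        "$$\\theta_i = \\Theta^{-2i/d}, \\qquad i = 0, 1, \\dots, d/2 - 1$$\n",
        "\n",
        "$$x'_a = x_a \\cos(m\\theta_i) - x_b \\sin(m\\theta_i), \\qquad x'_b = x_a \\sin(m\\theta_i) + x_b \\cos(m\\theta_i)$$\n",
        "\n",
        "When the query at position $m$ and the key at position $n$ are both rotated this way, their dot product depends on the positions only through the offset $m - n$."
      ]
    },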
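    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, we define the model architecture. It consists of three parts:\n",
        "\n",
        "- Embedding layer: converts the input token IDs into hidden states.\n",
        "- Decoder layers: the core of the model. Each decoder layer consists of a self-attention layer and a feed-forward network (FFN) layer.\n",
        "- Output layer: converts the hidden states into logits.\n",
        "\n",
        "We first define the FFN layer. Note that the FFN layer below is an optimized implementation that fuses the `gate` and `up` projections into one kernel. The naive implementation is `FFN(x) = down_proj(silu(gate(x)) * up(x))`. Combining the `gate` and `up` projections into a single kernel gives better performance. The optimized implementation is:\n",
        "```python\n",
        "concat_x = gate_up(x)\n",
        "gate_x, up_x = split(concat_x, 2, axis=-1)\n",
        "FFN(x) = down_proj(silu(gate_x) * up_x)\n",
        "```"
      ]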
    },
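    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The equivalence behind this fusion can be checked with a short derivation. This is a sketch under assumed notation: $W_g$ and $W_u$ denote the `gate` and `up` weight matrices, $W_d$ the `down_proj` weight, $H$ is `hidden_size`, $I$ is `intermediate_size`, and a bias-free linear layer is assumed to compute $xW^\\top$.\n",
        "\n",
        "Stacking the two weights row-wise, with the gate rows first, gives the fused weight and its output:\n",
        "\n",
        "$$W_{gu} = \\operatorname{concat}(W_g, W_u) \\in \\mathbb{R}^{2I \\times H}, \\qquad xW_{gu}^\\top = [\\, xW_g^\\top,\\; xW_u^\\top \\,] \\in \\mathbb{R}^{2I}$$\n",
        "\n",
        "Splitting the result into two halves along the last axis recovers $xW_g^\\top$ and $xW_u^\\top$, so\n",
        "\n",
        "$$\\mathrm{FFN}(x) = \\bigl(\\mathrm{silu}(xW_g^\\top) \\odot xW_u^\\top\\bigr) W_d^\\top$$\n",
        "\n",
        "is unchanged, while one matrix multiplication replaces two. With the TinyLlama configuration, the fused weight has shape $(2 \\cdot 5632, 2048) = (11264, 2048)$."
      ]
    },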
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class LlamaFFN(nn.Module):\n",
        "    def __init__(self, config: LlamaConfig):\n",
        "        super().__init__()\n",
        "        self.gate_up_proj = nn.Linear(\n",
        "            in_features=config.hidden_size,\n",
        "            out_features=2 * config.intermediate_size,\n",
        "            bias=False,\n",
        "        )\n",
        "        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)\n",
        "\n",
        "    def forward(self, x: Tensor):\n",
        "        concat_x1_x2 = self.gate_up_proj(x)\n",
        "        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n",
        "        return self.down_proj(op.silu(x1) * x2)"
      ]
    },
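    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then we define the **self-attention layer**. It consists of three parts:\n",
        "\n",
        "- QKV projection: converts the input hidden states into the `query`, `key`, and `value` tensors.\n",
        "- Attention: computes the attention scores and applies the `softmax` operation.\n",
        "- Output projection: converts the attention output back into hidden states.\n",
        "\n",
        "We optimize different parts of the self-attention layer:\n",
        "\n",
        "- QKV projection: we apply horizontal fusion to the three projections and fuse them into one kernel.\n",
        "- Attention: we pass the fused QKV tensor directly to the paged KV cache through `attention_with_fused_qkv`, so the model code does not split it into separate `query`, `key`, and `value` tensors."
      ]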
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class LlamaAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n",
        "    def __init__(self, config: LlamaConfig):\n",
        "        self.head_dim = config.head_dim\n",
        "        self.num_q_heads = config.num_attention_heads\n",
        "        self.num_kv_heads = config.num_key_value_heads\n",
        "        # horizontal fusion on QKV projection\n",
        "        self.qkv_proj = nn.Linear(\n",
        "            in_features=config.hidden_size,\n",
        "            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n",
        "            bias=False,\n",
        "        )\n",
        "        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n",
        "\n",
        "    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n",
        "        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n",
        "        b, s, _ = hidden_states.shape\n",
        "        # QKV Projection\n",
        "        qkv = self.qkv_proj(hidden_states)\n",
        "        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n",
        "        # Attention\n",
        "        output = op.reshape(\n",
        "            paged_kv_cache.attention_with_fused_qkv(\n",
        "                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n",
        "            ),\n",
        "            (b, s, h_q * d),\n",
        "        )\n",
        "        # Output Projection\n",
        "        return self.o_proj(output)"
      ]
    },
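    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sizes used in `LlamaAttention` can be worked out from the TinyLlama configuration. The arithmetic below only restates values from `LlamaConfig`; the symbols $h_q$, $h_{kv}$ and $d$ mirror the local variables `h_q`, `h_kv` and `d` in `forward`.\n",
        "\n",
        "$$(h_q + 2h_{kv})\\,d = (32 + 2 \\cdot 4) \\cdot 64 = 2560 \\quad \\text{(output width of the fused QKV projection)}$$\n",
        "\n",
        "$$h_q d = 32 \\cdot 64 = 2048 \\quad \\text{(input width of the output projection)}$$\n",
        "\n",
        "$$d^{-1/2} = 64^{-1/2} = 0.125 \\quad \\text{(softmax scale)}$$\n",
        "\n",
        "Since $h_q / h_{kv} = 32 / 4 = 8$, every key/value head is shared by eight query heads, which is why the fused projection is much narrower than $3 h_q d = 6144$."
      ]
    },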
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, we define the model architecture using the FFN and self-attention layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class LlamaDecoderLayer(nn.Module):\n",
        "    def __init__(self, config: LlamaConfig):\n",
        "        rms_norm_eps = config.rms_norm_eps\n",
        "        self.self_attn = LlamaAttention(config)\n",
        "        self.mlp = LlamaFFN(config)\n",
        "        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n",
        "        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n",
        "\n",
        "    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n",
        "        hidden_states += self.self_attn(\n",
        "            self.input_layernorm(hidden_states), paged_kv_cache, layer_id\n",
        "        )\n",
        "        hidden_states += self.mlp(self.post_attention_layernorm(hidden_states))\n",
        "        return hidden_states\n",
        "\n",
        "\n",
        "class LlamaModel(nn.Module):\n",
        "    def __init__(self, config: LlamaConfig):\n",
        "        assert config.hidden_size % config.num_attention_heads == 0\n",
        "        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n",
        "        self.layers = nn.ModuleList(\n",
        "            [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n",
        "        )\n",
        "        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n",
        "\n",
        "    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n",
        "        hidden_states = input_embed\n",
        "        for layer_id, layer in enumerate(self.layers):\n",
        "            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n",
        "        hidden_states = self.norm(hidden_states)\n",
        "        return hidden_states\n",
        "\n",
        "\n",
        "class LlamaForCasualLM(nn.Module):\n",
        "    def __init__(self, config: LlamaConfig):\n",
        "        self.model = LlamaModel(config)\n",
        "        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n",
        "        self.num_hidden_layers = config.num_hidden_layers\n",
        "        self.num_attention_heads = config.num_attention_heads\n",
        "        self.num_key_value_heads = config.num_key_value_heads\n",
        "        self.head_dim = config.head_dim\n",
        "        self.hidden_size = config.hidden_size\n",
        "        self.vocab_size = config.vocab_size\n",
        "        self.rope_theta = config.rope_theta\n",
        "        self.dtype = \"float32\"\n",
        "\n",
        "    def to(self, dtype: Optional[str] = None):\n",
        "        super().to(dtype=dtype)\n",
        "        if dtype is not None:\n",
        "            self.dtype = dtype\n",
        "\n",
        "    def embed(self, input_ids: Tensor):\n",
        "        return self.model.embed_tokens(input_ids)\n",
        "\n",
        "    def get_logits(self, hidden_states: Tensor):\n",
        "        logits = self.lm_head(hidden_states)\n",
        "        if logits.dtype != \"float32\":\n",
        "            logits = logits.astype(\"float32\")\n",
        "        return logits\n",
        "\n",
        "    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n",
        "        def _index(x: te.Tensor):  # x[:-1,:]\n",
        "            b, s, d = x.shape\n",
        "            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n",
        "\n",
        "        hidden_states = self.model(input_embed, paged_kv_cache)\n",
        "        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n",
        "        logits = self.get_logits(hidden_states)\n",
        "        return logits, paged_kv_cache\n",
        "\n",
        "    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n",
        "        hidden_states = self.model(input_embed, paged_kv_cache)\n",
        "        logits = self.get_logits(hidden_states)\n",
        "        return logits, paged_kv_cache\n",
        "\n",
        "    def create_tir_paged_kv_cache(\n",
        "        self,\n",
        "        max_batch_size: tir.Var,\n",
        "        max_total_seq_len: tir.Var,\n",
        "        prefill_chunk_size: tir.Var,\n",
        "        page_size: tir.Var,\n",
        "    ) -> PagedKVCache:\n",
        "        return TIRPagedKVCache(\n",
        "            attn_kind=\"mha\",\n",
        "            max_batch_size=max_batch_size,\n",
        "            max_total_seq_len=max_total_seq_len,\n",
        "            prefill_chunk_size=prefill_chunk_size,\n",
        "            page_size=page_size,\n",
        "            support_sliding_window=0,\n",
        "            layer_partition=relax.ShapeExpr([0, self.num_hidden_layers]),\n",
        "            num_hidden_layers=self.num_hidden_layers,\n",
        "            num_attention_heads=self.num_attention_heads,\n",
        "            num_key_value_heads=self.num_key_value_heads,\n",
        "            qk_head_dim=self.head_dim,\n",
        "            v_head_dim=self.head_dim,\n",
        "            mla_original_qk_head_dim=0,\n",
        "            mla_original_v_head_dim=0,\n",
        "            rope_mode=RopeMode.NORMAL,\n",
        "            rope_scale=1,\n",
        "            rope_theta=self.rope_theta,\n",
        "            rope_scaling={},\n",
        "            rope_ext_factors=relax.PrimValue(0),\n",
        "            rotary_dim=self.head_dim,\n",
        "            dtype=self.dtype,\n",
        "            target=target,\n",
        "            enable_disaggregation=False,\n",
        "        )\n",
        "\n",
        "    def get_default_spec(self):\n",
        "        mod_spec = {\n",
        "            \"embed\": {\n",
        "                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n",
        "                \"$\": {\n",
        "                    \"param_mode\": \"packed\",\n",
        "                    \"effect_mode\": \"none\",\n",
        "                },\n",
        "            },\n",
        "            \"prefill\": {\n",
        "                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n",
        "                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n",
        "                \"$\": {\n",
        "                    \"param_mode\": \"packed\",\n",
        "                    \"effect_mode\": \"none\",\n",
        "                },\n",
        "            },\n",
        "            \"decode\": {\n",
        "                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n",
        "                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n",
        "                \"$\": {\n",
        "                    \"param_mode\": \"packed\",\n",
        "                    \"effect_mode\": \"none\",\n",
        "                },\n",
        "            },\n",
        "            \"create_tir_paged_kv_cache\": {\n",
        "                \"max_batch_size\": int,\n",
        "                \"max_total_seq_len\": int,\n",
        "                \"prefill_chunk_size\": int,\n",
        "                \"page_size\": int,\n",
        "                \"$\": {\n",
        "                    \"param_mode\": \"none\",\n",
        "                    \"effect_mode\": \"none\",\n",
        "                },\n",
        "            },\n",
        "        }\n",
        "        return nn.spec.ModuleSpec.from_raw(mod_spec, self)"
      ]
    },
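    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Export the Model to Relax IRModule\n",
        "After defining the model architecture, we can export the model to a Relax IRModule. For demonstration, we only show part of the model architecture and parameters."
      ]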
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@R.function\n",
            "def prefill(input_embed: R.Tensor((1, \"seq_len\", 2048), dtype=\"float16\"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32000, 2048), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), 
dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), 
dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2560, 2048), dtype=\"float16\"), R.Tensor((2048, 2048), dtype=\"float16\"), R.Tensor((11264, 2048), dtype=\"float16\"), R.Tensor((2048, 5632), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((2048,), dtype=\"float16\"), R.Tensor((32000, 2048), dtype=\"float16\"))) -> R.Tuple(R.Tensor((1, 1, 32000), dtype=\"float32\"), R.Object):\n",
            "    seq_len = T.int64()\n",
            "    R.func_attr({\"num_input\": 2})\n",
            "    with R.dataflow():\n",
            "        model_embed_tokens_weight1: R.Tensor((32000, 2048), dtype=\"float16\") = packed_params[0]\n",
            "        model_layers_0_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype=\"float16\") = packed_params[1]\n",
            "        model_layers_0_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype=\"float16\") = packed_params[2]\n",
            "        model_layers_0_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype=\"float16\") = packed_params[3]\n",
            "        model_layers_0_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype=\"float16\") = packed_params[4]\n",
            "        model_layers_0_input_layernorm_weight1: R.Tensor((2048,), dtype=\"float16\") = packed_params[5]\n",
            "        model_layers_0_post_attention_layernorm_weight1: R.Tensor((2048,), dtype=\"float16\") = packed_params[6]\n",
            "        model_layers_1_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype=\"float16\") = packed_params[7]\n",
            "        model_layers_1_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype=\"float16\") = packed_params[8]\n",
            "        model_layers_1_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype=\"float16\") = packed_params[9]\n",
            "        model_layers_1_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype=\"float16\") = packed_params[10]\n",
            "        model_layers_1_input_layernorm_weight1: R.Tensor((2048,), dtype=\"float16\") = packed_params[11]\n",
            "        ...\n",
            "\n",
            "Parameters:\n",
            "[('model.embed_tokens.weight', Tensor([32000, 2048], \"float16\")),\n",
            " ('model.layers.0.self_attn.qkv_proj.weight', Tensor([2560, 2048], \"float16\")),\n",
            " ('model.layers.0.self_attn.o_proj.weight', Tensor([2048, 2048], \"float16\")),\n",
            " ('model.layers.0.mlp.gate_up_proj.weight', Tensor([11264, 2048], \"float16\")),\n",
            " ('model.layers.0.mlp.down_proj.weight', Tensor([2048, 5632], \"float16\"))]\n"
          ]
        }
      ],
      "source": [
        "model_config = LlamaConfig()\n",
        "model = LlamaForCasualLM(model_config)\n",
        "model.to(\"float16\")\n",
        "mod, named_params = model.export_tvm(spec=model.get_default_spec())\n",
        "prefill_str = mod[\"prefill\"].script()\n",
        "print(*prefill_str.split(\"\\n\")[3:20], sep=\"\\n\")  # Only show the first 10 lines for demonstration\n",
        "print(\"        ...\")\n",
        "\n",
        "print(\"\\nParameters:\")\n",
        "pprint(named_params[:5])  # Only show the first 5 parameters for demonstration"
      ]
    },
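    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define the Optimization Pipeline\n",
        "We define a series of optimization passes to optimize the model. This pipeline is designed specifically for LLMs."
      ]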
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
      ],
      "source": [
        "@register_pipeline(\"opt_llm\")\n",
        "def _pipeline(  # pylint: disable=too-many-arguments\n",
        "    ext_mods: List[nn.ExternModule] = None,\n",
        "):\n",
        "    ext_mods = ext_mods or []\n",
        "\n",
        "    @tvm.transform.module_pass(opt_level=0)\n",
        "    def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:\n",
        "        seq = tvm.transform.Sequential(\n",
        "            [\n",
        "                # Phase 1. Passes on high-level operator graph\n",
        "                # We can enable cublas for further optimization\n",
        "                relax.transform.FuseTransposeMatmul(),\n",
        "                # Phase 2. Lowering to TIR, inherited TVM Relax's official \"zero\" pipeline\n",
        "                relax.transform.LegalizeOps(),\n",
        "                relax.transform.AnnotateTIROpPattern(),\n",
        "                relax.transform.FoldConstant(),\n",
        "                relax.transform.FuseOps(),\n",
        "                relax.transform.FuseTIR(),\n",
        "                # Phase 3. Passes on TIR\n",
        "                relax.transform.DeadCodeElimination(),\n",
        "                # Phase 4. Low-level Optimizations\n",
        "                dlight.ApplyDefaultSchedule(\n",
        "                    dlight.gpu.Matmul(),\n",
        "                    dlight.gpu.GEMV(),\n",
        "                    dlight.gpu.Reduction(),\n",
        "                    dlight.gpu.GeneralReduction(),\n",
        "                    dlight.gpu.Fallback(),\n",
        "                ),\n",
        "                # Phase 5. Lowering to VM bytecode\n",
        "                relax.transform.RewriteDataflowReshape(),\n",
        "                relax.transform.ToNonDataflow(),\n",
        "                relax.transform.RemovePurityChecking(),\n",
        "                relax.transform.CallTIRRewrite(),\n",
        "                relax.transform.StaticPlanBlockMemory(),\n",
        "                relax.transform.RewriteCUDAGraph(),\n",
        "                relax.transform.LowerAllocTensor(),\n",
        "                relax.transform.KillAfterLastUse(),\n",
        "                relax.transform.LowerRuntimeBuiltin(),\n",
        "                relax.transform.VMShapeLower(),\n",
        "                relax.transform.AttachGlobalSymbol(),\n",
        "                relax.transform.AttachExternModules(ext_mods),\n",
        "            ]\n",
        "        )\n",
        "        mod = seq(mod)\n",
        "        return mod\n",
        "\n",
        "    return _pipeline\n",
        "\n",
        "\n",
        "with target:\n",
        "    ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline(\"opt_llm\"))\n",
        "    vm = relax.VirtualMachine(ex, dev)"
      ]
    },
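    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Prepare the Model Weights\n",
        "We load the pre-trained weights from Hugging Face and prepare the model weights. The pre-trained weights are stored in the Hugging Face format, so we need to load them and convert them into the model parameters.\n",
        "\n",
        "```bash\n",
        "pip install safetensors\n",
        "```"
      ]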
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "IS_IN_CI = os.getenv(\"CI\", \"\") == \"true\"\n",
        "\n",
        "HF_WEIGHT_PATH = None\n",
        "# HF_WEIGHT_PATH = Path(\"/path/to/TinyLlama-1.1B-Chat-v1.0/\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if not IS_IN_CI:\n",
        "    import numpy as np\n",
        "    import safetensors.torch\n",
        "    import torch\n",
        "\n",
        "    if HF_WEIGHT_PATH is None or not HF_WEIGHT_PATH.exists():\n",
        "        raise ValueError(\"Please set the HF_WEIGHT_PATH to the path of the pre-trained weights.\")\n",
        "\n",
        "    # Torch format weights\n",
        "    param_dict = safetensors.torch.load_file(HF_WEIGHT_PATH / \"model.safetensors\", device=\"cpu\")\n",
        "    # Numpy format weights\n",
        "    param_dict = {\n",
        "        k: v.half().numpy() if v.dtype == torch.bfloat16 else v.numpy()\n",
        "        for k, v in param_dict.items()\n",
        "    }\n",
        "\n",
        "    named_params = dict(named_params)\n",
        "    for i in range(model_config.num_hidden_layers):\n",
        "        # Add QKV in self attention\n",
        "        attn = f\"model.layers.{i}.self_attn\"\n",
        "        param_dict[f\"{attn}.qkv_proj.weight\"] = np.concatenate(\n",
        "            [\n",
        "                param_dict.pop(f\"{attn}.q_proj.weight\"),  # Pop the old parameters to save memory\n",
        "                param_dict.pop(f\"{attn}.k_proj.weight\"),\n",
        "                param_dict.pop(f\"{attn}.v_proj.weight\"),\n",
        "            ],\n",
        "            axis=0,\n",
        "        )\n",
        "        # Add gates in MLP\n",
        "        mlp = f\"model.layers.{i}.mlp\"\n",
        "        param_dict[f\"{mlp}.gate_up_proj.weight\"] = np.concatenate(\n",
        "            [\n",
        "                param_dict.pop(f\"{mlp}.gate_proj.weight\"),\n",
        "                param_dict.pop(f\"{mlp}.up_proj.weight\"),\n",
        "            ],\n",
        "            axis=0,\n",
        "        )\n",
        "\n",
        "    # Convert params into ndarray\n",
        "    params = [\n",
        "        tvm.runtime.tensor(param_dict[k].astype(\"float16\"), device=dev) for k in named_params.keys()\n",
        "    ]"
      ]
    },
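    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deploy the compiled model\n",
        "With the model and weights ready, we can deploy the compiled model on the target device. Language-model inference has two steps: prefill and decode. The prefill step processes the input tokens and stores their keys and values in the KVCache. The decode step then generates tokens one at a time until the end-of-sequence token is produced."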
      ]
    },
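    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Tokenization\n",
        "The first step is to tokenize the input prompt; the resulting tokens are later embedded into hidden states. Both tokenization and embedding are the same as in the original model. We use the Hugging Face tokenizer to tokenize the prompt. Note that different models require different tokenization and prompt formats; refer to the model's documentation for the correct ones."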
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if not IS_IN_CI:\n",
        "    from transformers import AutoTokenizer\n",
        "\n",
        "    tokenizer = AutoTokenizer.from_pretrained(HF_WEIGHT_PATH)\n",
        "    messages = [\n",
        "        {\"role\": \"user\", \"content\": \"What's your name?\"},\n",
        "    ]\n",
        "    prompt = tokenizer.apply_chat_template(messages)\n",
        "    input_len = len(prompt)\n",
        "\n",
        "    # Load prompt tokens into TVM ndarray on the target device\n",
        "    tokens = tvm.runtime.tensor(np.array(prompt).astype(\"int32\"), device=dev)"
      ]
    },
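    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Create the KVCache\n",
        "Before starting inference, we need to create the KVCache, which stores the key and value tensors of the attention layers. Apache TVM provides a PagedKVCache for this purpose. We create the PagedKVCache with the parameters specified in the call."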
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if not IS_IN_CI:\n",
        "    kv_cache = vm[\"create_tir_paged_kv_cache\"](\n",
        "        ShapeTuple([1]),  # max_batch_size=1\n",
        "        ShapeTuple([2048]),  # max_total_seq_len=2048\n",
        "        ShapeTuple([2048]),  # prefill_chunk_size=2048\n",
        "        ShapeTuple([16]),  # page_size=16\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Embedding\n",
        "The next step is to embed the tokens into hidden states. We use the `embed` function compiled in the Relax IRModule to do so."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "nd_view_func = tvm.get_global_func(\"vm.builtin.reshape\")\n",
        "\n",
        "\n",
        "def embed(tokens, params):\n",
        "    _embed = vm[\"embed\"](tokens, params)\n",
        "    # Reshape hidden from [seq_len, hidden_size] to [1, seq_len, hidden_size]\n",
        "    _embed = nd_view_func(_embed, ShapeTuple([1, _embed.shape[0], _embed.shape[1]]))\n",
        "    return _embed"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Prefill\n",
        "Before running the forward pass, we first fetch a few helper functions used for preparation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "add_sequence_func = tvm.get_global_func(\"vm.builtin.kv_state_add_sequence\")\n",
        "begin_forward_func = tvm.get_global_func(\"vm.builtin.kv_state_begin_forward\")\n",
        "end_forward_func = tvm.get_global_func(\"vm.builtin.kv_state_end_forward\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because we are creating a new sequence, we call `add_sequence_func` to initialize the request. We also call `begin_forward_func` before the forward pass and `end_forward_func` after it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if not IS_IN_CI:\n",
        "    seq_id = 0\n",
        "    add_sequence_func(kv_cache, seq_id)\n",
        "    hidden_states = embed(tokens, params)\n",
        "    begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([input_len]))\n",
        "    logits, kv_cache = vm[\"prefill\"](hidden_states, kv_cache, params)\n",
        "    end_forward_func(kv_cache)"
      ]
    },
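    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We now have the output `logits` from the prefill step. These `logits` are used to generate a token through sampling.\n",
        "\n",
        "This tutorial simplifies sampling by always picking the token with the highest probability (greedy decoding). In practice, tokens should be sampled according to the probability distribution. To keep the tutorial concise, sampling runs on the CPU."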
      ]
    },
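    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def sample_token(logits):\n",
        "    # Greedy sampling on the CPU: copy the logits to NumPy and take the argmax\n",
        "    logits_np = logits.numpy()\n",
        "    # Return a plain Python int so the id can be passed to the tokenizer\n",
        "    return int(np.argmax(logits_np))\n",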
        "\n",
        "\n",
        "if not IS_IN_CI:\n",
        "    last_token = sample_token(logits)\n",
        "    output_tokens = [last_token]"
      ]
    },
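    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Decode\n",
        "After the prefill step, we can start the decode step, which generates tokens until the end-of-sequence token is produced. We use the `decode` function compiled in the Relax IRModule to generate the tokens."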
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if not IS_IN_CI:\n",
        "    print(\"The generated tokens:\")\n",
        "\n",
        "    # Feed the last sampled token back into the model until EOS is produced\n",
        "    while last_token != tokenizer.eos_token_id:\n",
        "        tokens = tvm.runtime.tensor(np.array([last_token]).astype(\"int32\"), device=dev)\n",
        "        hidden_states = embed(tokens, params)\n",
        "        begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([1]))\n",
        "        logits, kv_cache = vm[\"decode\"](hidden_states, kv_cache, params)\n",
        "\n",
        "        end_forward_func(kv_cache)\n",
        "        last_token = sample_token(logits)\n",
        "        output_tokens.append(last_token)\n",
        "\n",
        "    print(tokenizer.decode(output_tokens))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "py313",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
